package nsalt.mem

import chisel3._
import chisel3.util._
import chiselFv._


// To be moved to a util package later.

// Holds a value between enables.
//   x  - the live value
//   en - when high, `x` is passed straight through and also captured
// Returns `x` while `en` is high; otherwise the value captured the last time
// `en` was high. The capturing register resets to the all-zero bit pattern of
// `x`'s type.
object HoldUnless {
  def apply[T <: Data](x: T, en: Bool): T = {
    Mux(
      en,
      x,                                 // enabled: transparent
      RegEnable(x, 0.U.asTypeOf(x), en)  // disabled: last captured value
    )
  }
}

// Reads a memory and keeps the read data stable (via HoldUnless) while no new
// read is being made.
object ReadAndHold {

  // `Mem` overload: the read result is held using `en` directly.
  def apply[T <: Data](x: Mem[T], addr: UInt, en: Bool): T = {
    HoldUnless(x.read(addr), en)
  }

  // `SyncReadMem` overload: the read is issued with `en`, and the hold is keyed
  // on `en` delayed by one cycle (RegNext), i.e. the cycle after the read was
  // issued.
  def apply[T <: Data](x: SyncReadMem[T], addr: UInt, en: Bool): T = {
    HoldUnless(x.read(addr, en), RegNext(en))
  }
}

// Component ports for building the requesting / responding ports of a sync read/write memory.
// They should not be used on their own.

// Address-only request payload.
//   sets - number of addressable entries; the index is log2Up(sets) bits wide
class SyncMemPortA(val sets: Int) extends Bundle {
  val index = Output(UInt(log2Up(sets).W))

  // Drives `index` from the given signal and returns this bundle so that calls
  // can be chained.
  def apply(index: UInt): SyncMemPortA = {
    this.index := index
    this
  }
}

// Write request payload: the index inherited from SyncMemPortA plus the data
// word and, when there is more than one way, a per-way write mask.
//   t    - type of the data word
//   sets - number of addressable entries (forwarded to SyncMemPortA)
//   ways - number of ways; the `waymask` field only exists when ways > 1
class SyncMemPortAW[T <: Data](private val t: T, sets: Int, val ways: Int = 1) extends SyncMemPortA(sets) {

  val data = Output(t)
  val waymask = if (ways > 1) Some(Output(UInt(ways.W))) else None

  // Drives index, data and (if the field exists) waymask from the given signals.
  // The `waymask` argument is ignored when ways == 1 because there is no field
  // to drive. Returns Unit, so unlike SyncMemPortA.apply this call cannot be
  // chained.
  def apply(data: T, index: UInt, waymask: UInt): Unit = {
    super.apply(index)
    this.data := data
    // foreach rather than map: the connection is a pure side effect.
    this.waymask.foreach(_ := waymask)
  }

}

// Read response payload: a vector holding one word of type `t` per way.
//   t    - type of each data word
//   ways - number of words in `data`
class SyncMemPortR[T <: Data](private val t: T, val ways: Int = 1) extends Bundle {
  val data = Output(Vec(ways, t))
}

// Read bus as seen from the requester: a decoupled index request going out and
// the per-way read data coming back (`res` is flipped).
//   t    - type of each returned data word
//   sets - number of addressable entries
//   ways - number of words returned per read
class SyncMemReadBus[T <: Data](private val t: T, sets: Int, val ways: Int = 1) extends Bundle {

  val req = Decoupled(new SyncMemPortA(sets))
  val res = Flipped(new SyncMemPortR(t, ways))

  // Drives the request's valid and index; `req.ready` is left to the other
  // side. Returns this bundle so that calls can be chained.
  def apply(valid: Bool, index: UInt): SyncMemReadBus[T] = {
    req.valid := valid
    req.bits.apply(index)
    this
  }
}

// Write bus as seen from the requester: a single decoupled request carrying
// index, data and (when ways > 1) a way mask.
//   t    - type of the data word
//   sets - number of addressable entries
//   ways - number of ways
class SyncMemWriteBus[T <: Data](private val t: T, sets: Int, val ways: Int = 1) extends Bundle {

  val req = Decoupled(new SyncMemPortAW(t, sets, ways))

  // Drives the request's valid and payload; `req.ready` is left to the other
  // side. Returns this bundle so that calls can be chained.
  def apply(valid: Bool, data: T, index: UInt, waymask: UInt): SyncMemWriteBus[T] = {
    req.valid := valid
    req.bits.apply(data = data, index = index, waymask = waymask)
    this
  }
}

// Synchronous-read memory with `sets` entries, each entry holding `ways` words.
//
//   t          - element type; stored as a flat UInt of t.getWidth bits and
//                cast back to `t` on read
//   sets       - number of entries
//   ways       - words per entry; a write stores the same word into every way
//                selected by the way mask
//   holdRead   - when true, the read data is wrapped in ReadAndHold
//   singlePort - when true, a read is only issued/accepted in cycles in which
//                no write (including the reset sweep) takes place
//
// After reset the module sweeps through all entries once, writing zero to
// every way. `io.onReset` is high while that sweep is running; neither reads
// nor writes are accepted (ready is low on both buses) until it has finished.
class SyncMem[T <: Data](
  t: T, sets: Int, ways: Int = 1,
  holdRead: Boolean = false, singlePort: Boolean = false
) extends Module {

  val io = IO(new Bundle {
    val r = Flipped(new SyncMemReadBus(t, sets, ways))
    val w = Flipped(new SyncMemWriteBus(t, sets, ways))
    val onReset = Output(Bool())
  })

  // Words are stored untyped so that any `t` can share the same memory shape.
  val wordType = UInt(t.getWidth.W)
  val mem = SyncReadMem(sets, Vec(ways, wordType))

  // Always reset after power-on
  // https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/utils/SRAMTemplate.scala#L79
  val onReset = RegInit(true.B)
  // The counter advances only while the sweep is active and supplies its address.
  val (resetAddr, resetDone) = Counter(onReset, sets)
  when (resetDone) {
    onReset := false.B
  }

  // The write port is busy either with an external request or with the sweep.
  val readyW = io.w.req.valid || onReset
  val readyR =
    if (singlePort) !readyW && io.r.req.valid
    else io.r.req.valid

  // During the sweep every way is written; afterwards the requester's mask is
  // used (a single-bit mask when the bus carries no waymask field).
  val waymask = Mux(onReset, Fill(ways, "b1".U), io.w.req.bits.waymask.getOrElse("b1".U))

  val addrW = Mux(onReset, resetAddr, io.w.req.bits.index)
  val wordW = Mux(onReset, 0.U.asTypeOf(wordType), io.w.req.bits.data.asUInt)
  // The same word is offered to every way; the mask picks which ways take it.
  val dataW = VecInit(Seq.fill(ways)(wordW))

  when (readyW) {
    mem.write(addrW, dataW, waymask.asBools)
  }

  val dataR = (
    if (holdRead)
      ReadAndHold(mem, io.r.req.bits.index, readyR)
    else
      mem.read(io.r.req.bits.index, readyR)
  ).map(_.asTypeOf(t))

  io.r.res.data := VecInit(dataR)

  io.r.req.ready := !onReset && (if (singlePort) !readyW else true.B)
  // While the sweep is running, addrW/wordW/waymask are taken from the sweep,
  // so an external write presented in those cycles is not stored. Do not
  // acknowledge it: hold ready low until the sweep has finished.
  io.w.req.ready := !onReset

  io.onReset := onReset
}


// A single-port SyncMem shared by `nRead` read ports and one write port.
//
//   nRead - number of read ports
//   gen   - element type stored in the memory
//   set   - number of entries
//   way   - number of ways per entry
//
// Read requests go through a chisel3.util.Arbiter onto the memory's single
// read bus; the write bus is connected straight through.
class SyncMemArbitrated[T <: Data](nRead: Int, gen: T, set: Int, way: Int = 1) extends Module {
  val io = IO(new Bundle {
    val r = Flipped(Vec(nRead, new SyncMemReadBus(gen, set, way)))
    val w = Flipped(new SyncMemWriteBus(gen, set, way))
  })

  val mem = Module(new SyncMem(gen, set, way, holdRead = false, singlePort = true))
  mem.io.w <> io.w

  // Funnel all read requests into the one read port of the memory.
  val readArb = Module(new Arbiter(chiselTypeOf(io.r(0).req.bits), nRead))
  readArb.io.in <> io.r.map(_.req)
  mem.io.r.req <> readArb.io.out

  // Every port sees the shared read data, but each latches it only in the
  // cycle after its own request fired and holds that value afterwards.
  for (port <- io.r) {
    port.res.data := HoldUnless(mem.io.r.res.data, RegNext(port.req.fire))
  }
}
